import torch
import os
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn as nn
from PIL import Image, ImageOps
import torch
import pdb
import numpy as np
import yaml
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
sys.path.append('..')
from template import utils
from torchvision.utils import save_image
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
# Load the experiment configuration; `with` guarantees the file handle is
# closed (the original bare open() call leaked it).
with open("config.yaml") as cfg_file:
    config = yaml.safe_load(cfg_file)
batch_size = int(config["BATCH_SIZE"])
print(f"Our config: {config}")
Our config: {'BATCH_SIZE': 64, 'NUM_EPOCHS': 10, 'LR': '3e-4'}
# Preprocessing: convert PIL images to float tensors in [0, 1], then resize
# to 128x128.
# NOTE(review): this mixes the legacy transforms API (ToTensor) with the v2
# API (v2.Resize); both interoperate, but using one API consistently would
# be cleaner.
transform = transforms.Compose([
    transforms.ToTensor(),
    v2.Resize((128, 128))
])
# CelebA train/valid/test splits (downloaded to ./data on first use).
train_dataset = torchvision.datasets.CelebA(root='./data', split='train',
                                            download=True, transform=transform)
valid_dataset = torchvision.datasets.CelebA(root='./data', split='valid',
                                            download=True, transform=transform)
test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                           download=True, transform=transform)
#create dataloaders
# Only the training loader shuffles; evaluation order is kept deterministic.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)
Files already downloaded and verified Files already downloaded and verified Files already downloaded and verified
# Sanity-check one batch; CelebA yields a 40-dim attribute vector per image.
imgs, labels = next(iter(testloader))
print(f"Image Shapes: {imgs.shape}")
print(f"Label Shapes: {labels.shape}")
Image Shapes: torch.Size([64, 3, 128, 128]) Label Shapes: torch.Size([64, 40])
# Visualize a few randomly chosen training images.
N_IMGS = 8
fig, ax = plt.subplots(1, N_IMGS)
fig.set_size_inches(3 * N_IMGS, 3)
ids = np.random.randint(low=0, high=len(train_dataset), size=N_IMGS)
for col, sample_idx in enumerate(ids):
    # (C, H, W) tensor -> (H, W, C) numpy array for imshow
    img = train_dataset[sample_idx][0].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    ax[col].imshow(img)
plt.show()
def save_model(model, optimizer, epoch, stats, exp_no=40120242):
    """Save a training checkpoint (model/optimizer state plus stats).

    Args:
        model: module whose state_dict is checkpointed.
        optimizer: optimizer whose state_dict is checkpointed.
        epoch: epoch index, embedded in the checkpoint filename.
        stats: picklable training statistics stored alongside the weights.
        exp_no: experiment id used to build the output directory.
    """
    savedir = os.path.join("experiments", f"experiment_{exp_no}", "models")
    # exist_ok avoids the race-prone exists()-then-makedirs() pattern.
    os.makedirs(savedir, exist_ok=True)
    savepath = os.path.join(savedir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)
    return
def load_model(model, optimizer, savepath, map_location=None):
    """Load a checkpoint produced by `save_model` into model and optimizer.

    Args:
        model: module to restore (in place).
        optimizer: optimizer to restore (in place).
        savepath: path to the .pth checkpoint file.
        map_location: optional device remapping forwarded to torch.load,
            e.g. "cpu" to load a GPU-trained checkpoint on a CPU-only host.
            Defaults to None (torch.load's original behavior).

    Returns:
        (model, optimizer, epoch, stats) as stored in the checkpoint.
    """
    checkpoint = torch.load(savepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]
    return model, optimizer, epoch, stats
def train_epoch(model, train_loader, optimizer, criterion, epoch, device, lambda_kld=1e-03):
    """ Training a model for one epoch """
    loss_hist, mse_hist, kld_hist = [], [], []
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (batch, _) in progress_bar:
        batch = batch.to(device)
        # Clear stale gradients before the new backward pass.
        optimizer.zero_grad()
        # Forward pass through the VAE.
        recons, (z, mu, log_var) = model(batch)
        # Combined reconstruction + KL loss.
        loss, (mse, kld) = criterion(recons, batch, mu, log_var, lambda_kld)
        loss_hist.append(loss.item())
        mse_hist.append(mse.item())
        kld_hist.append(kld.item())
        # Backprop and parameter update.
        loss.backward()
        optimizer.step()
        progress_bar.set_description(f"Epoch {epoch+1} Iter {step+1}: loss {loss.item():.5f}. ")
    return np.mean(loss_hist), loss_hist
@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", writer=None, lambda_kld=1e-03):
    """ Evaluating the model for either validation or test """
    totals = {"loss": [], "mse": [], "kld": []}
    for batch_idx, (batch, _) in enumerate(eval_loader):
        batch = batch.to(device)
        # Forward pass (gradients disabled by the decorator).
        recons, (z, mu, log_var) = model(batch)
        loss, (mse, kld) = criterion(recons, batch, mu, log_var, lambda_kld)
        totals["loss"].append(loss.item())
        totals["mse"].append(mse.item())
        totals["kld"].append(kld.item())
        # Dump a grid of reconstructions from the first batch only.
        if savefig and batch_idx == 0:
            save_image(recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png"))
    return np.mean(totals["loss"]), np.mean(totals["mse"]), np.mean(totals["kld"])
def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader,
                num_epochs, writer, savepath="", save_frequency=2, lambda_kld=1e-03):
    """ Training a model for a given number of epochs

    NOTE(review): relies on the module-level global `device` rather than
    taking it as a parameter. Each epoch runs validation BEFORE training, so
    the first recorded validation loss reflects the untrained model.
    Returns (train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld);
    per-epoch lists, except loss_iters which has one entry per iteration.
    """
    train_loss = []
    val_loss = []
    val_loss_recons = []
    val_loss_kld = []
    loss_iters = []
    for epoch in range(num_epochs):
        # validation epoch
        model.eval()  # important for dropout and batch norms
        # Save reconstructions/log on every save_frequency-th and the last epoch.
        log_epoch = (epoch % save_frequency == 0 or epoch == num_epochs - 1)
        loss, recons_loss, kld_loss = eval_model(
            model=model, eval_loader=valid_loader, criterion=criterion,
            device=device, epoch=epoch, savefig=log_epoch, savepath=savepath,
            writer=writer, lambda_kld=lambda_kld
        )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)
        val_loss_kld.append(kld_loss)
        # training epoch
        model.train()  # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
            model=model, train_loader=train_loader, optimizer=optimizer,
            criterion=criterion, epoch=epoch, device=device, lambda_kld=lambda_kld
        )
        # PLATEAU SCHEDULER: steps on the most recent validation loss.
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        loss_iters = loss_iters + cur_loss_iters
        if(epoch % save_frequency == 0):
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats)
        if(log_epoch):
            print(f"  Train loss: {round(mean_loss, 5)}")
            print(f"  Valid loss: {round(loss, 5)}")
            print(f"  Valid loss recons: {round(val_loss_recons[-1], 5)}")
            print(f"  Valid loss KL-D: {round(val_loss_kld[-1], 5)}")
    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld
In the following we outline some observations made when choosing the architecture of our ConvVAE. When employing a kernel size of 2, the images exhibited a blocky appearance, prompting a switch to a kernel size of 3 with a smaller stride, which mitigated the blockiness and produced more satisfactory outcomes. Attempts to enhance the model by increasing the number of linear layers in the encoder led to undesired brownish artifacts. Adding supplementary convolutional operations at the end of the network helped alleviate the problem. Introducing dropouts along with additional convolutional layers resulted in images with a brownish tint, despite an increase in detail.
class ConvVAE(nn.Module):
    """Convolutional VAE for 3x128x128 images with a 200-dim latent space."""

    def __init__(self):
        super(ConvVAE, self).__init__()
        # Encoder: three conv+pool stages, 128 -> 64 -> 32 -> 16 spatially,
        # flattened to 32 * 16 * 16 = 8192 features.
        enc_stages = []
        for c_in, c_out in ((3, 8), (8, 16), (16, 32)):
            enc_stages += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]
        enc_stages.append(nn.Flatten())
        self.encoder = nn.Sequential(*enc_stages)
        # Posterior heads. The log-variance (not the variance) is learned so
        # the head may output negative values, which eases training.
        self.mu = nn.Linear(8192, 200)
        self.log_var = nn.Linear(8192, 200)
        # Decoder: linear projection back to the encoder's feature volume,
        # three strided transposed convs (16 -> 33 -> 67 -> 135), then plain
        # convs that trim/refine down to a 3x128x128 image.
        self.decoder = nn.Sequential(
            nn.Linear(200, 8192),
            nn.ReLU(),
            nn.Unflatten(dim=1, unflattened_size=(32, 16, 16)),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=8, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, stride=1, padding=1),
        )

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) via the reparametrization trick."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)  # all stochasticity isolated here
        return mu + sigma * noise

    def forward(self, x):
        """Encode, sample a latent, decode; returns (x_hat, (z, mean, log_var))."""
        feats = self.encoder(x)
        mean, log_var = self.mu(feats), self.log_var(feats)
        z = self.reparameterize(mean, log_var)
        return self.decoder(z), (z, mean, log_var)
def vae_loss_function(recons, target, mu, log_var, lambda_kld=1e-3):
    """
    Weighted ELBO-style loss: pixel-wise MSE plus
    lambda_kld * KL(q(z|x) || N(0, I)).

    Returns (total_loss, (reconstruction_mse, kld)).
    """
    mse = F.mse_loss(recons, target)
    # Closed-form KL divergence between N(mu, sigma^2) and the unit Gaussian,
    # summed over latent dims, averaged over the batch. Derivation:
    # https://stats.stackexchange.com/questions/318748/deriving-the-kl-divergence-loss-for-vaes
    kld_per_sample = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1)
    kld = kld_per_sample.mean(dim=0)
    return mse + lambda_kld * kld, (mse, kld)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cvae = ConvVAE()
criterion = vae_loss_function
optimizer = torch.optim.Adam(cvae.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3, factor = 0.5, verbose = True)
cvae = cvae.to(device)
savepath = "/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_9"
'''train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
model=cvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
train_loader=trainloader, valid_loader=validloader, num_epochs=15, savepath=savepath, writer=None)'''
'train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(\n model=cvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,\n train_loader=trainloader, valid_loader=validloader, num_epochs=15, savepath=savepath, writer=None)'
# Reload the ConvVAE from a saved checkpoint (epoch 48 of experiment 8).
cvae = ConvVAE().to(device)
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
# The first validation entry comes from the untrained model, hence the offset.
utils.plot_loss_epoch(stats['train_loss'][:48], stats['valid_loss'][1:])
# NOTE(review): 'other_loss_stats' is not among the keys written by save_model
# above (train_loss / valid_loss / loss_iters) — this checkpoint presumably
# used an older stats layout; verify before re-running.
recons_loss = stats['other_loss_stats'][0]
Kl_d_loss = stats['other_loss_stats'][1]
epochs = range(1, len(Kl_d_loss) + 1)
# KL-divergence curve over epochs.
plt.figure(figsize=(10, 5))
plt.plot(epochs, Kl_d_loss, 'bo-', label='KLD Loss')
plt.title('KLD Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
# Reconstruction-loss curve over epochs.
plt.figure(figsize=(10, 5))
plt.plot(epochs, recons_loss, 'bo-', label='Reconstruction Loss')
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
Most of the loss development occurs in the earlier steps. Note that we omitted the first step in this plot to make the later changes in loss visible in the figure. Later on we only see minimal improvement. This is also reflected by the quality of the reconstruction images across epochs: the main changes (i.e., the face becomes visible and blockiness decreases) occur in the earlier epochs, while later epochs mostly make the reconstructions more detailed. The KL divergence converges quickly to a value around 10.5. As expected, given its larger weighting, the reconstruction loss looks very similar to the train/validation loss.
# Generate more images: sample latents from the prior and decode them.
with torch.no_grad():
    for batch_no in range(5):
        latents = torch.randn(64, 200).to(device)
        sample = cvae.decoder(latents)
        recons = sample.view(64, 3, 128, 128).cpu()
        fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed
        for col in range(10):
            img = recons[col].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
            axes[col].imshow(img)
            axes[col].axis('off')  # Turn off axis labels for clarity
        plt.tight_layout()
        plt.show()
# Compare ground-truth test images (top row) with their reconstructions
# (bottom row).
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
with torch.no_grad():
    sample, _ = cvae(test_data)
recons = sample.view(batch_size, 3, 128, 128).cpu()
fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed
for col in range(10):
    img = recons[col].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    test_img = test_data[col].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
    axes[0][col].imshow(test_img)
    axes[0][col].axis('off')
    axes[1][col].imshow(img)
    axes[1][col].axis('off')  # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
The reconstruction results of the standard ConvVAE are decent. The faces are clearly visible and some defining features of the reconstructions are visible as well. One downside is that the faces look very similar, like an "average" of all faces. This is expected, however, and is one of the reasons VQ-VAE and the architectures derived from it gained more popularity. Similar comments can be made about the generation results. We believe that improving our architecture even further would lead to even better results.
We decided to use a pretrained ResNet-18 encoder. The class below represents a basic block, which ResNet-18 consists of. One important difference between this basic block and the one used in the original ResNet is that here we use nn.ConvTranspose2d instead of nn.Conv2d. This is because we use it in the decoder, which is supposed to be a mirrored version of the encoder. In general, the architecture is the same in the encoder and the decoder, except for the number of filters in the decoder layers. The number(s) needed to be changed because, without doing so, the final reconstructed image dimensions would differ from the expected (3, 128, 128). Of course, this issue could be fixed by resizing the final image, but that created weird artifacts in the reconstructed image. For finetuning we chose end-to-end finetuning, given that it is a simple choice guaranteeing decent results.
class BasicBlock(nn.Module):
    """Transposed-convolution variant of the ResNet-18 basic block (decoder use)."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=2, transpose=False):
        super(BasicBlock, self).__init__()
        # NOTE(review): `transpose` is stored but never read; both convs are
        # always transposed.
        self.transpose = transpose
        # output_padding resolves the output-size ambiguity of a strided
        # ConvTranspose2d so upsampled sizes come out exactly doubled.
        out_pad = 1 if stride > 1 else 0
        self.conv1 = nn.ConvTranspose2d(in_planes, planes, kernel_size=3, stride=stride,
                                        padding=1, output_padding=out_pad, bias=False)
        self.conv2 = nn.ConvTranspose2d(planes, planes, kernel_size=3, stride=1,
                                        padding=1, output_padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Projection shortcut whenever spatial size or channel count changes;
        # identity otherwise.
        if stride == 1 and in_planes == self.expansion * planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(in_planes, self.expansion * planes, kernel_size=1,
                                   stride=stride, output_padding=out_pad, bias=False),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
class Decoder(nn.Module):
    """ResNet-18-mirrored decoder: latent vector -> (N, 3, 128, 128) image."""

    def __init__(self, block, num_blocks, latent_dim):
        super(Decoder, self).__init__()
        self.in_planes = 64 * block.expansion
        self.fc2 = nn.Linear(latent_dim, 4096)  # 4096 = 64 * 8 * 8
        # Filter counts shrink while spatial size grows; otherwise the layout
        # mirrors the ResNet-18 encoder (block counts consumed in reverse).
        self.layer1 = self._make_layer(block, 32, num_blocks[3], stride=2, transpose=True)
        self.layer2 = self._make_layer(block, 16, num_blocks[2], stride=2, transpose=True)
        self.layer3 = self._make_layer(block, 8, num_blocks[1], stride=2, transpose=True)
        self.layer4 = self._make_layer(block, 8, num_blocks[0], stride=1, transpose=True)
        self.conv1 = nn.ConvTranspose2d(8, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(3)

    def _make_layer(self, block, planes, num_blocks, stride, transpose):
        # The strided block goes LAST within a stage (strides reversed),
        # mirroring the encoder where the strided block comes first.
        strides = [stride] + [1] * (num_blocks - 1)
        stage = []
        for s in reversed(strides):
            stage.append(block(self.in_planes, planes, s, transpose))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.fc2(x))
        out = out.view(out.size(0), 64, 8, 8)  # reshape output of linear layer
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return F.relu(self.bn1(self.conv1(out)))
def DecoderResNet18(latent_dim):
    """Build the decoder with the ResNet-18 layout (two blocks per stage)."""
    return Decoder(BasicBlock, [2,2,2,2], latent_dim)
class ResNetVAE(nn.Module):
    """VAE whose encoder is a pretrained ResNet-18 and whose decoder mirrors it."""

    def __init__(self, latent_dim=512):
        super(ResNetVAE, self).__init__()
        # Pretrained classifier backbone; its 1000-dim output acts as features.
        self.encoder = torchvision.models.resnet18(weights='DEFAULT')
        # Posterior heads on top of the 1000-dim encoder output; the
        # log-variance is learned so the head may output negative values.
        self.mu = nn.Linear(1000, latent_dim)
        self.log_var = nn.Linear(1000, latent_dim)
        self.decoder = DecoderResNet18(latent_dim)

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) with the reparametrization trick."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)  # all stochasticity isolated here
        return mu + sigma * noise

    def forward(self, x):
        feats = self.encoder(x)
        mu, log_var = self.mu(feats), self.log_var(feats)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), (z, mu, log_var)
# ResNet-VAE training setup (lambda_kld = 1e-4 run).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnetvae = ResNetVAE(512)
criterion = vae_loss_function
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
# Halve the LR after 4 validation epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 4, factor = 0.5, verbose = True)
resnetvae = resnetvae.to(device)
num_epochs = 50
device
'cuda'
# Train the ResNet-VAE end to end with a small KL weight (1e-4).
train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
    model=resnetvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
    train_loader=trainloader, valid_loader=validloader, num_epochs=num_epochs, writer=None, lambda_kld=1e-4)
Epoch 1 Iter 2544: loss 0.03966. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [02:39<00:00, 15.94it/s]
Train loss: 0.07908
Valid loss: 0.57485
Valid loss recons: 0.28432
Valid loss KL-D: 2905.30506
Epoch 2 Iter 2544: loss 0.02919. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.56it/s] Epoch 3 Iter 2544: loss 0.02737. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s]
Train loss: 0.02353
Valid loss: 0.02543
Valid loss recons: 0.0209
Valid loss KL-D: 45.28074
Epoch 4 Iter 2544: loss 0.02349. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.53it/s] Epoch 5 Iter 2544: loss 0.02303. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.02179
Valid loss: 0.02297
Valid loss recons: 0.01813
Valid loss KL-D: 48.41507
Epoch 6 Iter 2544: loss 0.02189. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s] Epoch 7 Iter 2544: loss 0.02221. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s]
Train loss: 0.02095
Valid loss: 0.02095
Valid loss recons: 0.01612
Valid loss KL-D: 48.29697
Epoch 8 Iter 2544: loss 0.02217. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s] Epoch 9 Iter 2544: loss 0.02239. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
Train loss: 0.02037
Valid loss: 0.0205
Valid loss recons: 0.01553
Valid loss KL-D: 49.74597
Epoch 10 Iter 2544: loss 0.01970. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s] Epoch 11 Iter 2544: loss 0.01845. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
Train loss: 0.02
Valid loss: 0.01985
Valid loss recons: 0.01494
Valid loss KL-D: 49.08961
Epoch 12 Iter 2544: loss 0.01988. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s] Epoch 13 Iter 2544: loss 0.02589. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
Train loss: 0.02088
Valid loss: 0.0196
Valid loss recons: 0.01457
Valid loss KL-D: 50.2956
Epoch 14 Iter 2544: loss 0.01909. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s] Epoch 15 Iter 2544: loss 0.02095. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
Train loss: 0.01976
Valid loss: 0.01961
Valid loss recons: 0.01432
Valid loss KL-D: 52.85515
Epoch 16 Iter 2544: loss 0.01989. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.46it/s] Epoch 17 Iter 2544: loss 0.02059. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01954
Valid loss: 0.01948
Valid loss recons: 0.01435
Valid loss KL-D: 51.30299
Epoch 18 Iter 2544: loss 0.02027. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s] Epoch 19 Iter 2544: loss 0.02018. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
Train loss: 0.01958
Valid loss: 0.01942
Valid loss recons: 0.01421
Valid loss KL-D: 52.10604
Epoch 20 Iter 2544: loss 0.01931. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s] Epoch 21 Iter 2544: loss 0.02264. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
Train loss: 0.01934
Valid loss: 0.01925
Valid loss recons: 0.01401
Valid loss KL-D: 52.39189
Epoch 22 Iter 2544: loss 0.02076. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.54it/s] Epoch 23 Iter 2544: loss 0.01886. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01924
Valid loss: 0.01923
Valid loss recons: 0.01397
Valid loss KL-D: 52.6373
Epoch 24 Iter 2544: loss 0.01902. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s] Epoch 25 Iter 2544: loss 0.01988. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s]
Train loss: 0.01916
Valid loss: 0.01934
Valid loss recons: 0.01399
Valid loss KL-D: 53.5481
Epoch 26 Iter 2544: loss 0.02002. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s] Epoch 27 Iter 2544: loss 0.02156. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.27it/s]
Train loss: 0.01908
Valid loss: 0.019
Valid loss recons: 0.01366
Valid loss KL-D: 53.37492
Epoch 28 Iter 2544: loss 0.01837. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s] Epoch 29 Iter 2544: loss 0.02175. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01902
Valid loss: 0.01911
Valid loss recons: 0.01375
Valid loss KL-D: 53.59315
Epoch 30 Iter 2544: loss 0.01916. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s] Epoch 31 Iter 2544: loss 0.02035. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
Train loss: 0.01897
Valid loss: 0.01894
Valid loss recons: 0.01351
Valid loss KL-D: 54.28299
Epoch 32 Iter 2544: loss 0.02010. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.57it/s] Epoch 33 Iter 2544: loss 0.02093. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01891
Valid loss: 0.01885
Valid loss recons: 0.01343
Valid loss KL-D: 54.13779
Epoch 34 Iter 2544: loss 0.01982. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.44it/s] Epoch 35 Iter 2544: loss 0.02034. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Train loss: 0.01885
Valid loss: 0.01887
Valid loss recons: 0.01341
Valid loss KL-D: 54.5866
Epoch 36 Iter 2544: loss 0.02112. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s] Epoch 37 Iter 2544: loss 0.01945. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
Train loss: 0.01881
Valid loss: 0.01879
Valid loss recons: 0.0135
Valid loss KL-D: 52.97335
Epoch 38 Iter 2544: loss 0.01626. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s] Epoch 39 Iter 2544: loss 0.01797. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.46it/s]
Train loss: 0.01878
Valid loss: 0.01879
Valid loss recons: 0.0134
Valid loss KL-D: 53.91221
Epoch 40 Iter 2544: loss 0.02491. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.50it/s] Epoch 41 Iter 2544: loss 0.01838. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s]
Train loss: 0.01874
Valid loss: 0.01876
Valid loss recons: 0.01336
Valid loss KL-D: 54.01153
Epoch 42 Iter 2544: loss 0.02161. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s] Epoch 43 Iter 2544: loss 0.01922. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s]
Train loss: 0.01885
Valid loss: 0.01864
Valid loss recons: 0.01319
Valid loss KL-D: 54.45878
Epoch 44 Iter 2544: loss 0.01805. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.53it/s] Epoch 45 Iter 2544: loss 0.01930. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01869
Valid loss: 0.01871
Valid loss recons: 0.01325
Valid loss KL-D: 54.59709
Epoch 46 Iter 2544: loss 0.01862. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.55it/s] Epoch 47 Iter 2544: loss 0.01874. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.01868
Valid loss: 0.01912
Valid loss recons: 0.01363
Valid loss KL-D: 54.88449
Epoch 48 Iter 2544: loss 0.02108. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s] Epoch 49 Iter 2544: loss 0.01891. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:46<00:00, 23.86it/s]
Epoch 00049: reducing learning rate of group 0 to 5.0000e-04.
Train loss: 0.01864
Valid loss: 0.01892
Valid loss recons: 0.01329
Valid loss KL-D: 56.22169
Epoch 50 Iter 2544: loss 0.02033. : 100%|████████████████████████████████████████████████████████| 2544/2544 [08:38<00:00, 4.91it/s]
Train loss: 0.01845
Valid loss: 0.01856
Valid loss recons: 0.01306
Valid loss KL-D: 54.98609
Training completed
# Train/validation curves for the lambda_kld = 1e-4 run.
utils.plot_loss_epoch(train_loss, val_loss)
# Reload the ResNet-VAE checkpoint; the model and optimizer are restored in
# place by load_model even though only `stats` is bound here.
resnetvae = ResNetVAE(512)
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
# NOTE(review): the freshly built resnetvae is never moved to `device` here,
# while the sampling below sends latents to `device` — confirm this does not
# mismatch devices when CUDA is available.
_, _, _, stats = load_model(resnetvae, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
# Generate more images: sample latents from the prior and decode them.
with torch.no_grad():
    for i in range(5):
        z = torch.randn(64, 512).to(device)
        sample = resnetvae.decoder(z)
        recons = sample.view(64, 3, 128, 128).cpu()
        fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed
        # NOTE(review): the inner loop reuses the name `i`, shadowing the
        # outer batch index (harmless here, but confusing).
        for i in range(10):
            img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
            axes[i].imshow(img)
            axes[i].axis('off')  # Turn off axis labels for clarity
        plt.tight_layout()
        plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Compare ground-truth test images (top row) with their ResNet-VAE
# reconstructions (bottom row).
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
with torch.no_grad():
    sample, _ = resnetvae(test_data)
recons = sample.view(batch_size, 3, 128, 128).cpu()
fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed
for i in range(10):
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
    test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
    axes[0][i].imshow(test_img)
    axes[0][i].axis('off')
    axes[1][i].imshow(img)
    axes[1][i].axis('off')  # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Let's check the influence of the KL divergence lambda parameter by increasing it by a factor of 10.
# Second run: identical setup but with lambda_kld increased 10x to 1e-3.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnetvae = ResNetVAE(512)
criterion = vae_loss_function
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 4, factor = 0.5, verbose = True)
resnetvae = resnetvae.to(device)
train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
    model=resnetvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
    train_loader=trainloader, valid_loader=validloader, num_epochs=num_epochs, writer=None, lambda_kld=1e-3)
Epoch 1 Iter 2544: loss 0.05426. : 100%|████████████████████████████████████████████| 2544/2544 [07:37<00:00, 5.56it/s]
Train loss: 0.07213
Valid loss: 2.92893
Valid loss recons: 0.28448
Valid loss KL-D: 2644.44034
Epoch 2 Iter 2544: loss 0.04735. : 100%|████████████████████████████████████████████| 2544/2544 [02:56<00:00, 14.38it/s] Epoch 3 Iter 2544: loss 0.03960. : 100%|████████████████████████████████████████████| 2544/2544 [02:51<00:00, 14.81it/s]
Train loss: 0.04312
Valid loss: 0.0437
Valid loss recons: 0.03463
Valid loss KL-D: 9.06563
Epoch 4 Iter 2544: loss 0.03926. : 100%|████████████████████████████████████████████| 2544/2544 [01:58<00:00, 21.47it/s] Epoch 5 Iter 2544: loss 0.04034. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
Train loss: 0.04175
Valid loss: 0.04172
Valid loss recons: 0.03235
Valid loss KL-D: 9.37758
Epoch 6 Iter 2544: loss 0.04283. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.35it/s] Epoch 7 Iter 2544: loss 0.04417. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
Train loss: 0.0406
Valid loss: 0.04049
Valid loss recons: 0.03125
Valid loss KL-D: 9.23116
Epoch 8 Iter 2544: loss 0.04002. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s] Epoch 9 Iter 2544: loss 0.04285. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
Train loss: 0.04022
Valid loss: 0.04003
Valid loss recons: 0.03053
Valid loss KL-D: 9.50145
Epoch 10 Iter 2544: loss 0.03665. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s] Epoch 11 Iter 2544: loss 0.04062. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
Train loss: 0.03988
Valid loss: 0.03968
Valid loss recons: 0.03
Valid loss KL-D: 9.68452
Epoch 12 Iter 2544: loss 0.03864. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s] Epoch 13 Iter 2544: loss 0.04046. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s]
Train loss: 0.0397
Valid loss: 0.03939
Valid loss recons: 0.03
Valid loss KL-D: 9.38439
Epoch 14 Iter 2544: loss 0.03955. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.25it/s] Epoch 15 Iter 2544: loss 0.03783. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Train loss: 0.03959
Valid loss: 0.03923
Valid loss recons: 0.02957
Valid loss KL-D: 9.65668
Epoch 16 Iter 2544: loss 0.04080. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s] Epoch 17 Iter 2544: loss 0.03856. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
Train loss: 0.03946
Valid loss: 0.03912
Valid loss recons: 0.02938
Valid loss KL-D: 9.73264
Epoch 18 Iter 2544: loss 0.03900. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.25it/s] Epoch 19 Iter 2544: loss 0.06096. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Train loss: 0.03935
Valid loss: 0.03908
Valid loss recons: 0.02913
Valid loss KL-D: 9.94453
Epoch 20 Iter 2544: loss 0.04351. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s] Epoch 21 Iter 2544: loss 0.03699. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s]
Train loss: 0.03922
Valid loss: 0.03893
Valid loss recons: 0.02917
Valid loss KL-D: 9.75982
Epoch 22 Iter 2544: loss 0.04202. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.42it/s] Epoch 23 Iter 2544: loss 0.04551. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s]
Train loss: 0.0391
Valid loss: 0.03891
Valid loss recons: 0.02908
Valid loss KL-D: 9.83609
Epoch 24 Iter 2544: loss 0.03520. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s] Epoch 25 Iter 2544: loss 0.03931. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
Train loss: 0.03894
Valid loss: 0.03877
Valid loss recons: 0.02879
Valid loss KL-D: 9.98364
Epoch 26 Iter 2544: loss 0.03959. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s] Epoch 27 Iter 2544: loss 0.04638. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Train loss: 0.03886
Valid loss: 0.03864
Valid loss recons: 0.02846
Valid loss KL-D: 10.18017
Epoch 28 Iter 2544: loss 0.03527. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s] Epoch 29 Iter 2544: loss 0.03818. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
Train loss: 0.03883
Valid loss: 0.0385
Valid loss recons: 0.02849
Valid loss KL-D: 10.00849
Epoch 30 Iter 2544: loss 0.03845. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s] Epoch 31 Iter 2544: loss 0.03408. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Train loss: 0.03869
Valid loss: 0.03851
Valid loss recons: 0.02832
Valid loss KL-D: 10.19438
Epoch 32 Iter 2544: loss 0.03660. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s] Epoch 33 Iter 2544: loss 0.04607. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
Train loss: 0.0386
Valid loss: 0.03849
Valid loss recons: 0.02827
Valid loss KL-D: 10.22045
Epoch 34 Iter 2544: loss 0.04096. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s] Epoch 35 Iter 2544: loss 0.03882. : 100%|███████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s]
Train loss: 0.03854
Valid loss: 0.03828
Valid loss recons: 0.02809
Valid loss KL-D: 10.19361
Epoch 36 Iter 2544: loss 0.03771. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s] Epoch 37 Iter 2544: loss 0.03488. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
Train loss: 0.03849
Valid loss: 0.03835
Valid loss recons: 0.02809
Valid loss KL-D: 10.25628
Epoch 38 Iter 2544: loss 0.03541. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s] Epoch 39 Iter 2544: loss 0.03638. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.27it/s]
Train loss: 0.03845
Valid loss: 0.03826
Valid loss recons: 0.02804
Valid loss KL-D: 10.21794
Epoch 40 Iter 2544: loss 0.03822. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s] Epoch 41 Iter 2544: loss 0.04498. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
Train loss: 0.03843
Valid loss: 0.03815
Valid loss recons: 0.02773
Valid loss KL-D: 10.42156
Epoch 42 Iter 2544: loss 0.03487. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s] Epoch 43 Iter 2544: loss 0.03993. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
Train loss: 0.03841
Valid loss: 0.03811
Valid loss recons: 0.02763
Valid loss KL-D: 10.48405
Epoch 44 Iter 2544: loss 0.04279. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s] Epoch 45 Iter 2544: loss 0.03781. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
Train loss: 0.03835
Valid loss: 0.03827
Valid loss recons: 0.02803
Valid loss KL-D: 10.23958
Epoch 46 Iter 2544: loss 0.04364. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s] Epoch 47 Iter 2544: loss 0.03749. : 100%|███████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
Train loss: 0.03834
Valid loss: 0.03816
Valid loss recons: 0.02794
Valid loss KL-D: 10.22188
Epoch 48 Iter 2544: loss 0.03406. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s] Epoch 49 Iter 2544: loss 0.04041. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.35it/s]
Train loss: 0.03831
Valid loss: 0.03822
Valid loss recons: 0.02795
Valid loss KL-D: 10.27694
Epoch 50 Iter 2544: loss 0.04041. : 100%|███████████████████████████████████████████| 2544/2544 [01:45<00:00, 24.19it/s]
Train loss: 0.0383
Valid loss: 0.03812
Valid loss recons: 0.02782
Valid loss KL-D: 10.30331
Training completed
# Reload the higher-KLD ResNetVAE checkpoint and plot its loss curves.
savepath = "/experiments/experiment_40120242"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = vae_loss_function
# FIX: the model must exist before its optimizer can be built — the original
# constructed the Adam optimizer from `resnetvae` one line before (re)creating
# the model, so the optimizer was bound to a stale (or undefined) model.
resnetvae = ResNetVAE().to(device)
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
resnetvae, optimizer, epoch, stats = load_model(resnetvae, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
# Checkpoint was taken at epoch 48 -> plot the first 49 training entries.
utils.plot_loss_epoch(stats['train_loss'][:49], stats['valid_loss'])
In both cases (higher and lower KLD lambda/weight), the training loss hardly changes after the first epoch. Because of that, we used a learning-rate scheduler to decrease the learning rate after 4 epochs whenever no significant change in the validation loss was observed.
# Generate more images
# Sample 5 batches of 64 latent vectors from N(0, I), decode each batch with
# the ResNetVAE decoder, and show the first 10 decoded faces per batch.
# assumes latent dimensionality 512 — matches ResNetVAE(512) used elsewhere.
with torch.no_grad():
for i in range(5):
z = torch.randn(64, 512).to(device)
sample = resnetvae.decoder(z)
recons = sample.view(64, 3, 128, 128).cpu()
fig, axes = plt.subplots(1, 10, figsize=(128, 128)) # Adjust figsize as needed
# NOTE(review): the plot loop below reuses the outer index name `i`; harmless
# for a `for` over range() but worth renaming for clarity.
for i in range(10):
img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
axes[i].imshow(img)
axes[i].axis('off') # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Reconstruct one test batch with the ResNetVAE and show originals (top row)
# above their reconstructions (bottom row), 10 images each.
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
with torch.no_grad():
sample, _ = resnetvae(test_data)
recons = sample.view(batch_size, 3, 128, 128).cpu()
fig, axes = plt.subplots(2, 10, figsize=(128, 128)) # Adjust figsize as needed
for i in range(10):
img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
# test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
axes[0][i].imshow(test_img)
axes[0][i].axis('off')
axes[1][i].imshow(img)
axes[1][i].axis('off') # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
Manipulating kl divergence lambda value can lead to vastly different results. When KL Divergence lambda value was equal to 1e-4, faces were more diverse than when the value was equal to 1e-3. A more encompassing discussion of this issue can be found in the appended thread: https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational-auto-encoder. Basically, "higher values (of the KL divergence parameter) give a more structured latent space at the cost of poorer reconstruction, and lower values give better reconstruction with a less structured latent space" [1]. This makes sense and can be perfectly seen here. In the end, reconstructions created by the first model, were after all better, simply because of the better reconstruction. The thing we learned is, that a proper KL divergence weighting value must be chosen in order to achieve satisfying results.
We use the implementation of torchmetrics to calculate the frechet inception distance.
import torchmetrics
# FIX: FrechetInceptionDistance was imported twice on consecutive lines;
# one import suffices.
from torchmetrics.image.fid import FrechetInceptionDistance

# FID between ConvVAE samples ("fake") and real CelebA test images,
# using the 64-dimensional Inception feature layer.
fid = FrechetInceptionDistance(feature=64, normalize=True)
# Torchmetrics implementation does not use cuda for some reason
device = 'cpu'
cvae = ConvVAE()
# NOTE(review): `optimizer` here is whatever optimizer a previous cell left
# behind (built for a different model) — its loaded state is unused below,
# but confirm this is intended.
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
cvae = cvae.to(device)
test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                           download=True, transform=transform)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    # ConvVAE latent dimensionality is 200 (cf. the 512-dim ResNetVAE)
    z = torch.randn(64, 200).to(device)
    sample = cvae.decoder(z)
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
Files already downloaded and verified
100%|██████████| 312/312 [13:08<00:00, 2.53s/it]
tensor(5.0324)
# Report the FID computed above for the plain ConvVAE baseline.
print("FID for Vanilla ConvVAE (without pretrained encoder) is equal to 5.0324")
FID for Vanilla ConvVAE (without pretrained encoder) is equal to 5.0324
# FID for the ResNetVAE trained with the *lower* KLD weight (1e-4):
# decoded random latents form the "fake" set, test batches the "real" set.
fid = FrechetInceptionDistance(feature=64, normalize = True)
device = 'cpu'
resnetvae_KL_lower = ResNetVAE(512).to(device)
optimizer = torch.optim.Adam(resnetvae_KL_lower.parameters(), lr=1e-3)
resnetvae_KL_lower, optimizer, epoch, stats = load_model(resnetvae_KL_lower, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
test_data = test_data.to(device)
# NOTE(review): 64 fakes are sampled per step but the last real batch may be
# smaller, so fake/real counts can differ slightly — confirm intended.
z = torch.randn(64, 512).to(device)
sample = resnetvae_KL_lower.decoder(z)
fid.update(sample, real=False)
fid.update(test_data, real=True)
fid.compute()
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [12:35<00:00, 2.42s/it]
tensor(4.4890)
# Report the FID computed above; the lower-KLD ResNetVAE scores best.
print("FID for ResNetVAE with KLD_lambda = 1e-4 is equal to 4.4890")
FID for ResNetVAE with KLD_lambda = 1e-4 is equal to 4.4890
# FID for the ResNetVAE trained with the *higher* KLD weight (1e-3);
# same procedure as above with a fresh FID accumulator.
fid = FrechetInceptionDistance(feature=64, normalize = True)
resnetvae_KL_higher = ResNetVAE(512).to(device)
optimizer = torch.optim.Adam(resnetvae_KL_higher.parameters(), lr=1e-3)
resnetvae_KL_higher, optimizer, epoch, stats = load_model(resnetvae_KL_higher, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
test_data = test_data.to(device)
z = torch.randn(64, 512).to(device)
sample = resnetvae_KL_higher.decoder(z)
fid.update(sample, real=False)
fid.update(test_data, real=True)
fid.compute()
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [12:44<00:00, 2.45s/it]
tensor(5.8770)
# Report the FID computed above; the higher KLD weight scores worst of the three.
print("FID for ResNetVAE with KLD_lambda = 1e-3 is equal to 5.8770")
FID for ResNetVAE with KLD_lambda = 1e-3 is equal to 5.8770
FID is a metric for quantifying the realism and diversity of generated images. As we expected, the lowest FID value was achieved by ResNetVAE with a lower KLD weight. What is interesting is that ConvVAE achieved better results than ResNetVAE with a higher KLD weight. In the end, we can empirically confirm those numbers just by looking at the generated images. The best looking ones are generated by the ResNetVAE model with the lower KLD weight, and raising that weight diminished the quality of the images. Nevertheless, we consider both ResNetVAE models to have higher-quality generation than ConvVAE. Both ResNetVAE outputs look less blocky and more natural than those of our ConvVAE. This holds especially for the lower-KLD model, even though the backgrounds are less natural and the faces are sometimes cut off. The reconstruction quality is similar, and the ranking stays the same here. Especially the ResNet model with the lower KLD value shows more detail and quality, for the reasons discussed above.
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import jaccard_score
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster
def jaccard_similarity(x, y):
    """Generalized Jaccard similarity of two non-negative vectors:
    sum of element-wise minima over sum of element-wise maxima
    (1.0 for identical binary attribute vectors)."""
    intersection = np.minimum(x, y).sum()
    union = np.maximum(x, y).sum()
    return intersection / union
def assign_clusters(labels, n_clusters):
    """Hierarchically cluster binary attribute vectors.

    Args:
        labels: (n_samples, n_attributes) array of attribute vectors.
        n_clusters: maximum number of flat clusters to form.

    Returns:
        np.ndarray of 1-based cluster ids, one per sample.
    """
    # FIX: pdist expects a *distance* function, but the original passed the
    # Jaccard *similarity*, so average-linkage merged the most dissimilar
    # points first. Use the Jaccard distance (1 - similarity) instead.
    dists = pdist(labels,
                  metric=lambda x, y: 1.0 - np.sum(np.minimum(x, y)) / np.sum(np.maximum(x, y)))
    # Compute the linkage matrix (average linkage over the condensed distances)
    linkage_matrix = linkage(squareform(dists), method='average')
    # Cut the dendrogram into at most n_clusters flat clusters
    clusters = fcluster(linkage_matrix, n_clusters, criterion='maxclust')
    return clusters
def display_projections(points, labels, num_classes, ax):
    """Scatter-plot 2-D `points` on `ax`, one colour per cluster label.

    Points whose label equals k (k in 0..num_classes-1) are drawn with the
    k-th colour of a discrete 'RdYlGn' colormap; a matching colorbar is
    attached to the axes.
    """
    labels = np.array(labels)
    # Discrete colormap with one entry per class
    colormap = plt.get_cmap('RdYlGn', num_classes)
    for cls in range(num_classes):
        members = points[labels == cls]
        ax.scatter(members[:, 0], members[:, 1], color=colormap(cls),
                   label=f'Group {cls+1}')
    # Attach a colorbar keyed to the same colormap
    plt.colorbar(plt.cm.ScalarMappable(cmap=colormap), ax=ax)
# Projection setup: keep N points and colour them by num_clusters
# Jaccard-based attribute clusters.
N = 2000
num_clusters = 40
# Encode the whole test split with the ConvVAE, collecting flattened
# images, latent codes, and the 40-dim attribute labels.
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
for imgs, lbls in testloader:
imgs = imgs.to(device)
_, (z, _, _) = cvae(imgs)
imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
latents.append(z.cpu())
labels.append(lbls)
imgs_flat = np.concatenate(imgs_flat)
latents = np.concatenate(latents)
labels = np.concatenate(labels)
# Sanity check: one batch of labels has shape (batch, 40)
t, l = next(iter(testloader))
l.shape
torch.Size([64, 40])
# Hierarchical Jaccard clustering of the first N attribute vectors.
clusters = assign_clusters(labels[:N], num_clusters)
# NOTE(review): inspects `pca_imgs` from a previous notebook run — it is only
# (re)computed in the next cell; out-of-order execution artifact.
pca_imgs
array([[-25.339378 , -12.632147 ],
[-48.634953 , -25.977364 ],
[-28.15862 , 2.8991623],
...,
[ 42.41321 , 23.036709 ],
[ -6.5465856, 17.496992 ],
[-46.728096 , -2.5804029]], dtype=float32)
# 2-D PCA of raw images vs. ConvVAE latents, coloured by attribute cluster.
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
# Same image-vs-latent comparison with t-SNE instead of PCA.
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
# Re-encode the test split, now with the low-KLD ResNetVAE, for the
# same projection analysis.
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
for imgs, lbls in testloader:
imgs = imgs.to(device)
_, (z, _, _) = resnetvae_KL_lower(imgs)
imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
latents.append(z.cpu())
labels.append(lbls)
imgs_flat = np.concatenate(imgs_flat)
latents = np.concatenate(latents)
labels = np.concatenate(labels)
def jaccard_similarity(x, y):
    """Weighted Jaccard similarity: |min(x,y)| / |max(x,y)| element-wise sums."""
    return np.minimum(x, y).sum() / np.maximum(x, y).sum()
def assign_clusters(labels, n_clusters):
    """Group attribute vectors into at most `n_clusters` clusters via
    average-linkage hierarchical clustering in the Jaccard metric.

    Returns a 1-based cluster id per row of `labels`.
    """
    # FIX: the metric handed to pdist must be a *distance*; the original used
    # the raw Jaccard similarity, inverting the clustering criterion. The
    # Jaccard distance is 1 - similarity.
    dists = pdist(labels,
                  metric=lambda x, y: 1.0 - np.sum(np.minimum(x, y)) / np.sum(np.maximum(x, y)))
    # Average-linkage dendrogram over the condensed distance matrix
    linkage_matrix = linkage(squareform(dists), method='average')
    # Flatten the dendrogram into at most n_clusters clusters
    clusters = fcluster(linkage_matrix, n_clusters, criterion='maxclust')
    return clusters
# Re-cluster and PCA-project for the low-KLD ResNetVAE latents.
N = 2000
num_clusters = 40
clusters = assign_clusters(labels[:N], num_clusters)
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
#display_projections(pca_imgs[:N ], labels[:N], ax=ax[0])
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
# t-SNE counterpart of the PCA plots above (low-KLD ResNetVAE).
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
# Re-encode the test split with the high-KLD ResNetVAE for the final
# round of projections.
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
for imgs, lbls in testloader:
imgs = imgs.to(device)
_, (z, _, _) = resnetvae_KL_higher(imgs)
imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
latents.append(z.cpu())
labels.append(lbls)
imgs_flat = np.concatenate(imgs_flat)
latents = np.concatenate(latents)
labels = np.concatenate(labels)
# Cluster and PCA-project for the high-KLD ResNetVAE latents.
clusters = assign_clusters(labels[:N], num_clusters) # repeat the steps
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
# t-SNE counterpart for the high-KLD ResNetVAE latents.
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, shadow=True, ncol=5)
plt.show()
Given the large number of identities (labels), we tried to cluster the images ourselves.
In our problem context, we are dealing with 40-dimensional characteristic vectors, for instance, [1, 0, 0, 1, 0 ,1, …]. Each index in this vector represents a specific attribute, such as whether the individual is smiling (first index), or if the individual has long hair (second index), and so on.
The Jaccard similarity is particularly effective in this scenario because it calculates the similarity or diversity between two sets. This is done by comparing the intersection and the union of the sets, as illustrated in the attached diagram. Consequently, individuals who share a greater number of characteristics will have a higher similarity score. This makes the Jaccard similarity a powerful tool for our problem.
That's why we used Jaccard similarity to form clusters, i.e. points whose attribute vectors are more similar to each other under the Jaccard measure were put in the same group/cluster/class.
Firstly, we observe that while it seems natural to us to cluster faces on striking features, the model considers more information and thus considers different images to be close together. This is intuitive, considering how large a portion of an image the background usually occupies. Secondly, we observe that this disentanglement of the clusters transfers to the latent space, which is intuitive for similar reasons. Whether a general disentanglement of clusters carries over from the projection to the latent space requires more research and more precise clustering.
Disregarding clusters, one can observe that a general structure of the latent space is hard to discern. Thus also making analysis difficult. We do observe however that our latent space mirrors the shape of a normal distribution, which makes sense given we impose this assumption onto our model.
# Reload all three trained models on CPU for the interpolation figures.
device = 'cpu'
cvae = ConvVAE().to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=3e-4)
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
# ResNetVAE trained with the lower KLD weight (1e-4)
resnetvae_KL_lower = ResNetVAE().to(device)
optimizer = torch.optim.Adam(resnetvae_KL_lower.parameters(), lr=1e-3)
resnetvae_KL_lower, optimizer, epoch, stats = load_model(resnetvae_KL_lower, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
# ResNetVAE trained with the higher KLD weight (1e-3)
resnetvae_KL_higher = ResNetVAE().to(device)
optimizer = torch.optim.Adam(resnetvae_KL_higher.parameters(), lr=1e-3)
resnetvae_KL_higher, optimizer, epoch, stats = load_model(resnetvae_KL_higher, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
@torch.no_grad()
def plot_reconstructed(model, test_imgs, N=25):
    """Plot a grid of decoded latent interpolations between four images.

    Encodes `test_imgs` (expects at least 4 images), takes the first four
    latent codes z1..z4, and builds a bilinear interpolation grid: the row
    endpoints interpolate z2->z3 and z1->z4, and each row interpolates
    between those endpoints column-wise. Every interpolated code is decoded
    and shown (one figure per row).

    NOTE(review): parameter `N` is unused; kept for backward compatibility.
    """
    model = model.eval()
    _, (z, _, _) = model(test_imgs)
    z1, z2, z3, z4 = z
    # 9 evenly spaced interpolation weights in [0, 1]
    checkpoints = [0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1]
    interpolated_latent_vectors_LHS = [lam * z3 + (1 - lam) * z2 for lam in checkpoints]
    interpolated_latent_vectors_RHS = [lam * z4 + (1 - lam) * z1 for lam in checkpoints]
    for LHS, RHS in zip(interpolated_latent_vectors_LHS, interpolated_latent_vectors_RHS):
        interpolated_row = [lam * LHS + (1 - lam) * RHS for lam in checkpoints]
        fig, ax = plt.subplots(1, len(checkpoints), figsize=(27, 3), sharey=True, sharex=True)
        for col, code in enumerate(interpolated_row):
            ax[col].axis('off')
            # Decode a single latent code -> (3, 128, 128) image
            recon = model.decoder(code.view(1, code.size(0)))
            img = recon.view(3, 128, 128).cpu().detach().numpy().transpose(1, 2, 0)
            # FIX: dropped the unused maxValue/minValue locals; clip to the
            # valid imshow range instead of relying on matplotlib's warning.
            img = np.clip(img, 0, 1)
            ax[col].imshow(img)
        fig.tight_layout(pad=0, h_pad=0)
# NOTE(review): the first two calls use `test_data` left over from an earlier
# cell; a fresh batch is only fetched two lines below (out-of-order notebook
# execution) — confirm the same images were intended for all three models.
plot_reconstructed(resnetvae_KL_lower, test_data[0:4])
plot_reconstructed(resnetvae_KL_higher, test_data[0:4])
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
plot_reconstructed(cvae, test_data[0:4])
The most notable face interpolation is produced by the ResNetVAE model (lower KLD weight). There, a change of faces can actually be seen, and the quality of the reconstructed images is also the best. Second place belongs to ConvVAE, whose interpolation consists of faces that change only a little, but one could argue that it still looks better than ResNetVAE with the higher KLD weight, where it can be clearly seen that there is hardly any interpolation, because the face remains the same in all the interpolated images.
def save_model(model, optimizer, epoch, stats, exp_no = 7):
    """Persist a training checkpoint.

    Writes model/optimizer state plus training stats to
    experiments/experiment_<exp_no>/models/checkpoint_epoch_<epoch>.pth,
    creating the directory if needed.
    """
    model_dir = f"experiments/experiment_{exp_no}/models"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    savepath = f"{model_dir}/checkpoint_epoch_{epoch}.pth"
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats,
    }
    torch.save(checkpoint, savepath)
    return
def load_model(model, optimizer, savepath):
    """Restore a checkpoint written by save_model.

    Loads model and optimizer state dicts in place and returns
    (model, optimizer, epoch, stats).
    """
    checkpoint = torch.load(savepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer, checkpoint["epoch"], checkpoint["stats"]
def train_epoch(model, train_loader, optimizer, criterion, epoch, device):
    """Training a model for one epoch.

    `criterion(recons, aux_loss, images)` must return (total_loss,
    recons_loss). Returns (mean_loss, loss_list) over all iterations.
    """
    loss_list = []
    recons_loss = []
    # FIX: removed the `vae_loss` accumulator that was never written or read.
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, _) in progress_bar:
        images = images.to(device)
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        # Forward pass
        recons, quant_loss = model(images)
        # Calculate Loss
        loss, mse = criterion(recons, quant_loss, images)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        # Getting gradients w.r.t. parameters
        loss.backward()
        # Updating parameters
        optimizer.step()
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")
    mean_loss = np.mean(loss_list)
    return mean_loss, loss_list
@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", writer=None):
    """Evaluate `model` over `eval_loader` without gradient tracking.

    Returns (mean_total_loss, mean_recons_loss). When `savefig` is set, the
    first batch's reconstructions are written to `savepath`. `writer` is
    accepted for interface compatibility but not used here.
    """
    total_losses, recon_losses = [], []
    for batch_idx, (images, _) in enumerate(eval_loader):
        images = images.to(device)
        # Forward pass through the autoencoder
        recons, quant_loss = model(images)
        loss, mse = criterion(recons, quant_loss, images)
        total_losses.append(loss.item())
        recon_losses.append(mse.item())
        # Dump a sample grid of the first batch's reconstructions if requested
        if batch_idx == 0 and savefig:
            save_image(recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png"))
    return np.mean(total_losses), np.mean(recon_losses)
def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader,
                num_epochs, savepath, writer, save_frequency=2):
    """Training a model for a given number of epochs.

    Each epoch runs a validation pass first (optionally saving sample
    reconstructions), then a training epoch, then steps the plateau
    scheduler on the latest validation loss. Checkpoints are written every
    `save_frequency` epochs via save_model.

    NOTE(review): relies on the module-level `device` rather than a
    parameter; save_model is called with its default exp_no — confirm.

    Returns (train_loss, val_loss, loss_iters, val_loss_recons).
    """
    train_loss = []
    val_loss = []
    val_loss_recons = []
    # FIX: removed `val_loss_kld`, which was declared but never populated
    # or returned.
    loss_iters = []
    for epoch in range(num_epochs):
        # validation epoch
        model.eval() # important for dropout and batch norms
        log_epoch = (epoch % save_frequency == 0 or epoch == num_epochs - 1)
        loss, recons_loss = eval_model(
            model=model, eval_loader=valid_loader, criterion=criterion,
            device=device, epoch=epoch, savefig=log_epoch, savepath=savepath,
            writer=writer
        )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)
        # training epoch
        model.train() # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
            model=model, train_loader=train_loader, optimizer=optimizer,
            criterion=criterion, epoch=epoch, device=device
        )
        # PLATEAU SCHEDULER: step on the most recent validation loss
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        loss_iters = loss_iters + cur_loss_iters
        if(epoch % save_frequency == 0):
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats)
        if(log_epoch):
            print(f" Train loss: {round(mean_loss, 5)}")
            print(f" Valid loss: {round(loss, 5)}")
            print(f" Valid loss recons: {round(val_loss_recons[-1], 5)}")
    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons
# Code was taken from: https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
# Note, that the commitment_cost is the beta parameter from the paper
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck (VQ-VAE).

    Maps each spatial position of a BCHW feature map to its nearest codebook
    embedding. forward() returns (loss, quantized BCHW features with
    straight-through gradients, codebook perplexity, one-hot encodings).
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # FIX: removed the leftover debug print() calls from the hot path.
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input to (B*H*W, embedding_dim)
        flat_input = inputs.view(-1, self._embedding_dim)
        # Squared L2 distance of every position to every codebook vector
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
        # One-hot encoding of the nearest codebook entry per position
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        # Codebook + commitment loss (commitment_cost is beta from the paper)
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        # Straight-through estimator: gradients bypass the argmin
        quantized = inputs + (quantized - inputs).detach()
        # Perplexity of codebook usage (exp of the encoding entropy)
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
class VQVAE(nn.Module):
    """Convolutional VQ-VAE: a 3-layer conv encoder (each layer halves H/W),
    a vector-quantization bottleneck over 64-dim features with a 512-entry
    codebook, and a mirrored transposed-conv decoder."""

    def __init__(self):
        super(VQVAE, self).__init__()
        # Encoder: 3 -> 16 -> 32 -> 64 channels, stride-2 convs halve H/W
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.VQ = VectorQuantizer(embedding_dim=64, num_embeddings=512, commitment_cost=0.25)
        # Decoder mirrors the encoder, doubling H/W per layer
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        features = self.encoder(x)
        q_loss, quantized, _ppl, _encodings = self.VQ(features)
        out = self.decoder(quantized)
        return out, q_loss
def vqvae_loss_function(recons, quantize_losses, target):
    """Total VQ-VAE training loss.

    Args:
        recons: reconstructed batch produced by the decoder.
        quantize_losses: scalar codebook + commitment loss from the quantizer.
        target: the original input batch.

    Returns:
        Tuple of (reconstruction MSE + quantizer loss, quantizer loss).
    """
    reconstruction_term = F.mse_loss(recons, target)
    total = reconstruction_term + quantize_losses
    return total, quantize_losses
# Build the model, restore a trained checkpoint, and reconstruct one test batch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vqvae = VQVAE()
criterion = vqvae_loss_function
optimizer = torch.optim.Adam(vqvae.parameters(), lr=3e-4)
# Halve the learning rate after 3 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3, factor = 0.5, verbose = True)
vqvae = vqvae.to(device)
# NOTE(review): `load_model` is defined elsewhere in the project, and the
# hard-coded absolute checkpoint path will break on other machines --
# consider moving it into config.yaml.
vqvae, optimizer, epoch, stats = load_model(vqvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_7/models/checkpoint_epoch_18.pth')
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
# Inference only: no gradient tracking needed.
with torch.no_grad():
    sample, _ = vqvae(test_data)
recons = sample.view(batch_size, 3, 128, 128).cpu()
# Show 10 originals (top row) against their reconstructions (bottom row).
fig, axes = plt.subplots(2, 10, figsize=(128, 128)) # Adjust figsize as needed
for i in range(10):
    # CHW float tensor -> HWC numpy image for imshow.
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
    test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
    axes[0][i].imshow(test_img)
    axes[0][i].axis('off')
    axes[1][i].imshow(img)
    axes[1][i].axis('off') # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
/home/user/lschulze/anaconda3/envs/lab/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True). warnings.warn( /home/user/lschulze/anaconda3/envs/lab/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True). warnings.warn(
input Shape torch.Size([64, 64, 16, 16]) Encodings Shape torch.Size([16384, 512]) torch.Size([16384, 512]) encodings are one hots tensor(1., device='cuda:0') quantized shape torch.Size([16384, 64])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
The reconstructions of our VQ-VAE look very good and detailed. Compared to the reconstructions shown above, these clearly rank highest, exhibiting the greatest amount of detail. Unfortunately, there are some visual artifacts in the reconstructed images. However, given our prior findings, we are confident that a more complex encoder/decoder would resolve this issue.
def choose_rows_uniformly(matrix, num_rows_to_choose):
    """Sample rows of ``matrix`` uniformly at random, with replacement.

    Args:
        matrix: 2-D tensor of shape (num_rows, num_cols).
        num_rows_to_choose: number of rows to draw.

    Returns:
        Tensor of shape (num_rows_to_choose, num_cols) whose rows are
        copies of rows of ``matrix``.
    """
    row_count = matrix.size(0)
    # Uniform draw with replacement over [0, row_count).
    picks = torch.randint(0, row_count, (num_rows_to_choose,), dtype=torch.long)
    return matrix[picks]
# Example usage:
# Assuming 'your_matrix' is a PyTorch tensor with shape (num_rows, num_columns)
# Sample latent vectors uniformly from the learned 512x64 codebook.
your_matrix = vqvae.VQ._embedding.weight
# 16384 = 64 images * 16 * 16 latent grid positions per image.
num_rows_to_choose = 16384
selected_rows = choose_rows_uniformly(your_matrix, num_rows_to_choose)
print("Selected Rows:")
print(selected_rows.shape)
Selected Rows: torch.Size([16384, 64])
with torch.no_grad():
    # Arrange the sampled codebook vectors as a (batch, C, H, W) latent grid
    # and decode them directly -- i.e. generation under a uniform prior.
    z = selected_rows.reshape(64, 64, 16, 16).to(device)
    # z = torch.rand(64, 64, 16, 16).to(device)
    # z = vqvae.encoder(z)
    # _, z, _ , _ = vqvae.VQ(z)
    sample = vqvae.decoder(z)
recons = sample.view(64, 3, 128, 128).cpu()
torch.Size([64, 3, 128, 128])
torch.Size([64, 3, 128, 128])
# Display the first 10 generated samples.
fig, axes = plt.subplots(1, 10, figsize=(128, 128)) # Adjust figsize as needed
for i in range(10):
    # CHW float tensor -> HWC numpy image for imshow.
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].axis('off') # Turn off axis labels for clarity
plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Unfortunately, generating images with the VQ-VAE was not successful, clearly ranking last compared to the results above. We believe the reason is that a uniform prior over the codebook is too weak an assumption. Considering that, during reconstruction, the encoded vectors have a very particular structure determined by the input, this seems intuitive. Using a probability distribution given by an RNN or LSTM trained on existing latent representations would therefore be the better choice.
from torchmetrics.image.fid import FrechetInceptionDistance
# FID with the 64-feature Inception layer; normalize=True means the metric
# expects float images in [0, 1].
# NOTE(review): the decoder's final layer is an unbounded ConvTranspose2d, so
# generated samples may fall outside [0, 1] -- confirm this is acceptable.
fid = FrechetInceptionDistance(feature=64, normalize = True)
# Torchmetrics implementation does not use cuda for some reason
device = 'cpu'
vqvae = vqvae.to(device)
test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                           download=True, transform=transform)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size = 64)
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    # Draw a fresh uniform codebook sample per batch: 16384 vectors form
    # 64 latent grids of shape 64x16x16.
    selected_rows = choose_rows_uniformly(your_matrix, num_rows_to_choose)
    z = selected_rows.reshape(64, 64, 16, 16).to(device)
    sample = vqvae.decoder(z)
    # NOTE(review): the last CelebA batch may contain fewer than 64 real
    # images while 64 fakes are always generated; FID aggregates feature
    # statistics so unequal counts work, but confirm the imbalance is intended.
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
Files already downloaded and verified
100%|██████████| 312/312 [08:57<00:00, 1.72s/it]
tensor(85.4611)
The FID score of the VQ-VAE results is the highest (i.e., worst) compared to all prior models. This is expected, given the reconstructions viewed above.